{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ecc18f12-db18-46c1-9bf4-7dc85d7c59c8",
   "metadata": {},
   "source": [
    "## Adding Web Search To Your LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d90723f-17e9-4697-98e7-052f8613bf45",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "# Populate the process environment via python-dotenv (presumably from a local .env file);\n",
    "# a later cell reads TAVILY_API_KEY from os.environ. The return value is discarded via `_`.\n",
    "_ = load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adf60829-5347-4e78-88ad-d720e4f370f0",
   "metadata": {},
   "source": [
    "#### Asking Your LLM For The Latest Information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "051e1fb4-7730-4458-ad0b-c63eea7781e5",
   "metadata": {
    "height": 149
   },
   "outputs": [],
   "source": [
    "from utils import query_raven\n",
    "\n",
    "# The user's question, asked with no external context attached.\n",
    "question = \"Hey, can you tell me more about this R1 thing that was announced by Rabbit? \"\n",
    "\n",
    "# Plain instruction prompt: no function definitions or retrieved text are supplied.\n",
    "no_function_calling_prompt = f\"\\n<s> [INST] {question} [/INST]\\n\"\n",
    "\n",
    "query_raven(no_function_calling_prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1d2c1f5-78fc-4ee1-94f3-23e883f463ee",
   "metadata": {},
   "source": [
    "#### Providing Up To Date Information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f6ca927-0e66-4db4-9746-961de1ef81d1",
   "metadata": {
    "height": 370
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def do_web_search(full_user_prompt : str, num_results : int = 5):\n",
    "    \"\"\"\n",
    "    Query the Tavily search API and return the text of the matching results.\n",
    "\n",
    "    Parameters:\n",
    "    - full_user_prompt (str): The search query sent to the API.\n",
    "    - num_results (int): Value sent as `max_results` (default 5).\n",
    "\n",
    "    Returns: One string holding each result's `content` field, joined by blank lines.\n",
    "    Raises: KeyError if TAVILY_API_KEY is unset; a requests exception on timeout or on a\n",
    "    4xx/5xx reply.\n",
    "    \"\"\"\n",
    "    import requests  # kept function-local so the cell's only module-level import stays `os`\n",
    "\n",
    "    api_url = f'{os.getenv(\"DLAI_TAVILY_BASE_URL\", \"https://api.tavily.com\")}/search'\n",
    "    payload = {\n",
    "        \"api_key\": os.environ[\"TAVILY_API_KEY\"],\n",
    "        \"query\": full_user_prompt,\n",
    "        \"search_depth\": \"basic\",\n",
    "        \"include_answer\": False,\n",
    "        \"include_images\": False,\n",
    "        \"include_raw_content\": False,\n",
    "        \"max_results\": num_results,\n",
    "        \"include_domains\": [],\n",
    "        \"exclude_domains\": []\n",
    "    }\n",
    "    # Bound the wait so a stalled request cannot hang the kernel.\n",
    "    response = requests.post(api_url, json=payload, timeout=30)\n",
    "    # Fail with the HTTP status (bad key, quota, outage) instead of a KeyError on \"results\".\n",
    "    response.raise_for_status()\n",
    "    search_data = response.json()\n",
    "    return \"\\n\\n\".join(item[\"content\"] for item in search_data[\"results\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b54b8d1e-d084-4e22-8b42-a15510e4149d",
   "metadata": {},
   "source": [
    "#### Calling Raven"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8910dd2e-ed87-4883-8508-48a4718ee5aa",
   "metadata": {
    "height": 285
   },
   "outputs": [],
   "source": [
    "# Prompt for the function-calling model: one Python signature with its docstring, a worked\n",
    "# example call, and the user's query. `{query}` is filled in by str.format() below.\n",
    "function_calling_prompt = \\\n",
    "\"\"\"\n",
    "Function:\n",
    "def do_web_search(full_user_prompt : str, num_results : int = 5):\n",
    "    '''\n",
    "    Searches the web for the user question.\n",
    "    '''\n",
    "\n",
    "Example:\n",
    "User Query: What is the oldest capital in the world?\n",
    "Call: do_web_search(full_user_prompt=\"oldest capital\")\n",
    "\n",
    "User Query: {query}<human_end>\n",
    "\"\"\"\n",
    "# The model's reply is kept as a string in fc_result and printed for inspection.\n",
    "fc_result = query_raven(function_calling_prompt.format(query=question))\n",
    "print (fc_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d25889b3-b7db-4144-aebd-b038717d4d1f",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# SECURITY: eval() runs whatever Python expression the model returned. The intent here is a\n",
    "# do_web_search(...) call, but the text is untrusted model output -- inspect it before running.\n",
    "result = eval(fc_result)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a13e640c-e6c3-4103-bd36-60f5ef87c39a",
   "metadata": {
    "height": 200
   },
   "outputs": [],
   "source": [
    "# f-string: {result} (value produced by the eval'd call) and {question} are interpolated\n",
    "# immediately, when this cell runs.\n",
    "full_prompt = \\\n",
    "f\"\"\"\n",
    "<s> [INST]\n",
    "{result}\n",
    "\n",
    "Use the information above to answer the following question concisely.\n",
    "\n",
    "Question:\n",
    "{question} [/INST]\n",
    "\"\"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0de0eab-230b-4951-896e-719e99740ec6",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "# full_prompt is an f-string that was fully interpolated when it was defined, so it is sent\n",
    "# as-is. Calling .format() on it would raise on any literal '{' or '}' in the search text.\n",
    "grounded_response = query_raven(full_prompt)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d01c0b-4327-4b96-b361-4bde9c34a266",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "# Show the model's answer to the grounded prompt.\n",
    "print(grounded_response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a01059e6-4f7e-4158-9411-c5241888d90f",
   "metadata": {},
   "source": [
    "## Chatting With Your SQL Database\n",
    "> Note below: The database values are randomly generated so your values may differ from those in the video.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d4498e7-2321-4217-9b5d-676e66123d37",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from utils import create_random_database\n",
    "# Helper from the local utils module; presumably (re)builds the toy database queried below\n",
    "# with random values -- confirm in utils.py.\n",
    "create_random_database()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c24db79-d9bd-4850-861b-b246938e45a6",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "# Rebinds `question` (previously the web-search question) for the database example.\n",
    "question = \"What is the most expensive item we currently sell?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0b9d650-daba-4374-a6d1-5d8f0a382a8f",
   "metadata": {
    "height": 455
   },
   "outputs": [],
   "source": [
    "from utils import execute_sql, query_raven\n",
    "\n",
    "# Table definition included in the prompt so the model can see the available columns.\n",
    "schema = \\\n",
    "\"\"\"\n",
    "CREATE TABLE IF NOT EXISTS toys (\n",
    "    id INTEGER PRIMARY KEY,\n",
    "    name TEXT,\n",
    "    price REAL\n",
    ");\n",
    "\"\"\"\n",
    "\n",
    "# Prompt exposing a single execute_sql(sql_code) function plus the schema and the question.\n",
    "raven_prompt = \\\n",
    "f'''\n",
    "Function:\n",
    "def execute_sql(sql_code : str):\n",
    "  \"\"\"\n",
    "  Runs sql code for a company internal database\n",
    "  \"\"\"\n",
    "\n",
    "Schema: {schema}\n",
    "User Query: {question}\n",
    "'''\n",
    "\n",
    "output = query_raven(raven_prompt)\n",
    "print (f\"LLM's function call: {output}\")\n",
    "# SECURITY: eval() executes the model's reply, so the model chooses the SQL handed to\n",
    "# execute_sql; nothing in this cell restricts it to read-only statements.\n",
    "database_result = eval(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34109001-b30e-485a-a917-8639f5c6f241",
   "metadata": {
    "height": 217
   },
   "outputs": [],
   "source": [
    "# Ground the final answer in the value returned by the database call.\n",
    "full_prompt = f\"\"\"\n",
    "<s> [INST]\n",
    "{database_result}\n",
    "\n",
    "Use the information above to answer the following question concisely.\n",
    "\n",
    "Question:\n",
    "{question} [/INST]\n",
    "\"\"\"\n",
    "\n",
    "grounded_response = query_raven(full_prompt)\n",
    "print(grounded_response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ace25e8-4526-4b10-97e4-4829e6a561fa",
   "metadata": {},
   "source": [
    "### Safer Interactions With Databases\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "554765e6-dcf4-4c79-b23e-3f78f0a2cf41",
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "import random\n",
    "\n",
    "# Name of the SQLite file used by the toy examples\n",
    "DB_NAME = 'toy_database.db'\n",
    "\n",
    "def connect_db():\n",
    "    \"\"\"Open a new sqlite3 connection to DB_NAME and return it.\"\"\"\n",
    "    connection = sqlite3.connect(DB_NAME)\n",
    "    return connection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58cfcc25-6324-46b4-9b69-dcce24106f66",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# List all toys\n",
    "def list_all_toys():\n",
    "    \"\"\"\n",
    "    Return every row of the toys table.\n",
    "\n",
    "    Returns: A list of row tuples, as produced by cursor.fetchall().\n",
    "    \"\"\"\n",
    "    conn = connect_db()\n",
    "    try:\n",
    "        # sqlite3's `with conn:` only manages the transaction and never closes the\n",
    "        # connection, so close it explicitly to avoid leaking one handle per call.\n",
    "        return conn.execute('SELECT * FROM toys').fetchall()\n",
    "    finally:\n",
    "        conn.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0772168-ef5e-41bb-9181-0c38f896055d",
   "metadata": {
    "height": 132
   },
   "outputs": [],
   "source": [
    "# Find toy by name prefix\n",
    "def find_toy_by_prefix(prefix):\n",
    "    \"\"\"\n",
    "    Return the toys whose name starts with `prefix`.\n",
    "\n",
    "    Parameters:\n",
    "    - prefix (str): Leading text of the name; bound as a query parameter, not spliced into SQL.\n",
    "      Any '%' or '_' inside it still acts as a LIKE wildcard.\n",
    "    Returns: A list of matching row tuples.\n",
    "    \"\"\"\n",
    "    query = 'SELECT * FROM toys WHERE name LIKE ?'\n",
    "    conn = connect_db()\n",
    "    try:\n",
    "        # `with conn:` would leave the connection open; close it explicitly instead.\n",
    "        return conn.execute(query, (prefix + '%',)).fetchall()\n",
    "    finally:\n",
    "        conn.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cf610e1-6384-405c-a81c-00f1f8443008",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# Find toys in a price range\n",
    "def find_toys_in_price_range(low_price, high_price):\n",
    "    \"\"\"\n",
    "    Return the toys whose price lies between `low_price` and `high_price` (SQL BETWEEN,\n",
    "    so both bounds are inclusive).\n",
    "\n",
    "    Parameters:\n",
    "    - low_price: Lower bound, bound as a query parameter.\n",
    "    - high_price: Upper bound, bound as a query parameter.\n",
    "    Returns: A list of matching row tuples.\n",
    "    \"\"\"\n",
    "    query = 'SELECT * FROM toys WHERE price BETWEEN ? AND ?'\n",
    "    conn = connect_db()\n",
    "    try:\n",
    "        # `with conn:` would leave the connection open; close it explicitly instead.\n",
    "        return conn.execute(query, (low_price, high_price)).fetchall()\n",
    "    finally:\n",
    "        conn.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76e38d51-88c1-45b3-aa92-99df03fa7b49",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# Get a random selection of toys\n",
    "def get_random_toys(count=5):\n",
    "    \"\"\"\n",
    "    Return up to `count` toys picked at random (without repetition) from the toys table.\n",
    "\n",
    "    Parameters:\n",
    "    - count (int): Maximum number of rows to return (default 5); capped at the table size.\n",
    "    Returns: A list of row tuples chosen by random.sample.\n",
    "    \"\"\"\n",
    "    conn = connect_db()\n",
    "    try:\n",
    "        all_toys = conn.execute('SELECT * FROM toys').fetchall()\n",
    "    finally:\n",
    "        # `with conn:` would leave the connection open; release it before sampling.\n",
    "        conn.close()\n",
    "    return random.sample(all_toys, min(count, len(all_toys)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92a358df-c160-45ca-9b3e-e7c4602648c4",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# Function to get the most expensive toy\n",
    "def get_most_expensive_toy(count=1):\n",
    "    \"\"\"\n",
    "    Return the highest-priced row of the toys table.\n",
    "\n",
    "    Parameters:\n",
    "    - count (int): Value bound to the LIMIT clause (default 1). Only the first row is\n",
    "      fetched, so values above 1 do not change the result.\n",
    "    Returns: A single row tuple, or None when the query yields no rows.\n",
    "    \"\"\"\n",
    "    conn = connect_db()\n",
    "    try:\n",
    "        # Bind LIMIT as an integer parameter instead of formatting it into the SQL text:\n",
    "        # the argument may come straight from model output.\n",
    "        cursor = conn.execute('SELECT * FROM toys ORDER BY price DESC LIMIT ?', (int(count),))\n",
    "        return cursor.fetchone()\n",
    "    finally:\n",
    "        # `with conn:` would leave the connection open; close it explicitly instead.\n",
    "        conn.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a625c123-0bb6-423d-b690-4851f6f70a80",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "# Function to get the cheapest toy\n",
    "def get_cheapest_toy(count=1):\n",
    "    \"\"\"\n",
    "    Return the lowest-priced row of the toys table.\n",
    "\n",
    "    Parameters:\n",
    "    - count (int): Value bound to the LIMIT clause (default 1). Only the first row is\n",
    "      fetched, so values above 1 do not change the result.\n",
    "    Returns: A single row tuple, or None when the query yields no rows.\n",
    "    \"\"\"\n",
    "    conn = connect_db()\n",
    "    try:\n",
    "        # LIMIT is bound as a parameter; a plain (non-f) string containing '{count}' would\n",
    "        # reach SQLite verbatim and fail to parse.\n",
    "        cursor = conn.execute('SELECT * FROM toys ORDER BY price ASC LIMIT ?', (int(count),))\n",
    "        return cursor.fetchone()\n",
    "    finally:\n",
    "        # `with conn:` would leave the connection open; close it explicitly instead.\n",
    "        conn.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b507eb02-6779-4742-b0b2-48fb2fcaba1a",
   "metadata": {
    "height": 1050
   },
   "outputs": [],
   "source": [
    "# Prompt listing the purpose-built query helpers defined above, so the model selects one of\n",
    "# them instead of writing raw SQL.\n",
    "# NOTE(review): the prompt text for get_most_expensive_toy / get_cheapest_toy declares a\n",
    "# `count` parameter yet says the function takes no parameters, and get_random_toys omits\n",
    "# `count` while its description mentions a specified count -- worth reconciling.\n",
    "raven_prompt = \\\n",
    "f'''\n",
    "Function:\n",
    "def list_all_toys():\n",
    "    \"\"\"\n",
    "    Retrieves a list of all toys from the database. This function does not take any parameters.\n",
    "    Returns: A list of tuples, where each tuple represents a toy with all its attributes (id, name, price).\n",
    "    \"\"\"\n",
    "\n",
    "Function:\n",
    "def find_toy_by_prefix(prefix):\n",
    "    \"\"\"\n",
    "    Searches for and retrieves toys whose names start with a specified prefix.\n",
    "    Parameters:\n",
    "    - prefix (str): The prefix to search for in toy names.\n",
    "    Returns: A list of tuples, where each tuple represents a toy that matches the prefix criteria.\n",
    "    \"\"\"\n",
    "\n",
    "Function:\n",
    "def find_toys_in_price_range(low_price, high_price):\n",
    "    \"\"\"\n",
    "    Finds and returns toys within a specified price range.\n",
    "    Parameters:\n",
    "    - low_price (float): The lower bound of the price range.\n",
    "    - high_price (float): The upper bound of the price range.\n",
    "    Returns: A list of tuples, each representing a toy whose price falls within the specified range.\n",
    "    \"\"\"\n",
    "\n",
    "Function:\n",
    "def get_random_toys():\n",
    "    \"\"\"\n",
    "    Selects and returns a random set of toys from the database, simulating a \"featured toys\" list.\n",
    "\n",
    "    Returns: A list of tuples, each representing a randomly selected toy. The number of toys returned is up to the specified count.\n",
    "    \"\"\"\n",
    "\n",
    "Function:\n",
    "def get_most_expensive_toy(count : int):\n",
    "    \"\"\"\n",
    "    Retrieves the most expensive toy from the database.\n",
    "    This function does not take any parameters.\n",
    "\n",
    "    Returns: A tuple representing the most expensive toy, including its id, name, and price.\n",
    "    \"\"\"\n",
    "\n",
    "Function:\n",
    "def get_cheapest_toy(count : int):\n",
    "    \"\"\"\n",
    "    Finds and retrieves the cheapest toy in the database.\n",
    "    This function does not take any parameters.\n",
    "\n",
    "    Returns: A tuple representing the cheapest toy, including its id, name, and price.\n",
    "    \"\"\"\n",
    "\n",
    "User Query: {question}<human_end>\n",
    "\n",
    "'''\n",
    "\n",
    "output = query_raven(raven_prompt)\n",
    "print (output)\n",
    "# eval() still executes model output. The prompt only advertises the fixed helpers above, but\n",
    "# NOTE(review): eval itself does not enforce that restriction.\n",
    "results = eval(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cda6a9e-454e-47af-9f01-cbd872ce1815",
   "metadata": {
    "height": 217
   },
   "outputs": [],
   "source": [
    "# Ground the answer in `results`, the value returned by the helper call evaluated in the\n",
    "# previous cell -- not `database_result`, which belongs to the earlier raw-SQL example.\n",
    "full_prompt = \\\n",
    "f\"\"\"\n",
    "<s> [INST]\n",
    "{results}\n",
    "\n",
    "Use the information above to answer the following question in a single sentence.\n",
    "\n",
    "Question:\n",
    "{question} [/INST]\n",
    "\"\"\"\n",
    "grounded_response = query_raven(full_prompt)\n",
    "print(grounded_response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1267ca30-e928-4655-af6b-66aabcf9e639",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
